{-# LANGUAGE CPP, TypeFamilies, NoMonomorphismRestriction, FlexibleInstances, ScopedTypeVariables #-}
{- Copyright JP Bernardy 2008 -}

-- | Generic syntax tree handling functions
module Yi.Syntax.Tree (IsTree(..), toksAfter, allToks, tokAtOrBefore, toksInRegion,
                       sepBy, sepBy1,
                       getLastOffset, getFirstOffset,
                       getFirstElement, getLastElement,
                       getLastPath,
                       getAllSubTrees,
                       tokenBasedAnnots, tokenBasedStrokes,
                       subtreeRegion,
                       fromLeafToLeafAfter, fromNodeToFinal) 
  where

-- Some of this might be replaced by a generic package
-- such as multirec, uniplace, emgm, ...

import Prelude (curry)

import Control.Arrow (first)
import Data.List (dropWhile, takeWhile, reverse, filter, zip, take, drop, length, splitAt)
import Data.Maybe
import Data.Monoid
#ifdef TESTING
import Test.QuickCheck
#endif

import Yi.Buffer.Basic
import Yi.Lexer.Alex
import Yi.Prelude
import Yi.Region

-- Fundamental types
type Path = [Int]
type Node t = (Path, t)

class Foldable tree => IsTree tree where
    -- | Direct subtrees of a tree
    subtrees :: tree t -> [tree t]
    subtrees = fst . uniplate
    uniplate :: tree t -> ([tree t], [tree t] -> tree t)
    emptyNode :: tree t

toksAfter :: Foldable t1 => t -> t1 a -> [a]
toksAfter _begin = allToks

allToks :: Foldable t => t a -> [a]
allToks = toList

tokAtOrBefore :: Foldable t => Point -> t (Tok t1) -> Maybe (Tok t1)
tokAtOrBefore p res = listToMaybe $ reverse $ toksInRegion (mkRegion 0 (p+1)) res

toksInRegion :: Foldable t1 => Region -> t1 (Tok t) -> [Tok t]
toksInRegion reg = takeWhile (\t -> tokBegin t <= regionEnd   reg) . dropWhile (\t -> tokEnd t < regionStart reg) . toksAfter (regionStart reg)

tokenBasedAnnots :: (Foldable t1) => (a1 -> Maybe a) -> t1 a1 -> t -> [a]
tokenBasedAnnots tta t begin = catMaybes $ fmap tta $ toksAfter begin t

tokenBasedStrokes :: (Foldable t3) => (a -> b) -> t3 a -> t -> t2 -> t1 -> [b]
tokenBasedStrokes tts t _point begin _end = fmap tts $ toksAfter begin t

-- | Prune the nodes before the given point.
-- The path is used to know which nodes we can force or not.
pruneNodesBefore :: IsTree tree => Point -> Path -> tree (Tok a) -> tree (Tok a)
pruneNodesBefore _ [] t = t
pruneNodesBefore p (x:xs) t = rebuild (left' ++ pruneNodesBefore p xs c : rights)
    where (children,rebuild) = uniplate t
          (left,c:rights) = splitAt x children
          left' = fmap replaceEmpty left
          replaceEmpty s = if getLastOffset s < p then emptyNode else s

-- | Given an approximate path to a leaf at the end of the region, return:
-- (path to leaf at the end of the region,path from focused node to the leaf, small node encompassing the region)
fromNodeToFinal :: IsTree tree => Region -> Node (tree (Tok a)) -> Node (tree (Tok a))
fromNodeToFinal r (xs,root) = 
    trace ("r = " ++ show r) $
    trace ("focused ~ " ++ show (subtreeRegion focused) ) $
    trace ("pathFromFocusedToLeaf = " ++ show focusedToLeaf) $
    trace ("pruned ~ " ++ show (subtreeRegion focused) ) $

    (xs', pruned)

    where n@(xs',_) = fromLeafToLeafAfter (regionEnd r) (xs,root)
          (_,(focusedToLeaf,focused)) = fromLeafAfterToFinal p0 n
          p0 = regionStart r
          pruned = pruneNodesBefore p0 focusedToLeaf focused

-- | Return the first element that matches the predicate, or the last of the list
-- if none matches.
firstThat :: (a -> Bool) -> [a] -> a
firstThat _ [] = error "firstThat: empty list"
firstThat _ [x] = x
firstThat p (x:xs) = if p x then x else firstThat p xs

-- | Return the element before first element that violates the predicate, or the first of the list
-- if that one violates the predicate.
lastThat :: (a -> Bool) -> [a] -> a
lastThat _ [] = error "lastThat: empty list"
lastThat p (x:xs) = if p x then work x xs else x
    where work x0 [] = x0
          work x0 (y:ys) = if p y then work y ys else x0

-- | Given a path to a node, return a path+node which 
-- node that encompasses the given node + a point before it.
fromLeafAfterToFinal :: IsTree tree => Point -> Node (tree (Tok a)) -> (Path, Node (tree (Tok a)))
fromLeafAfterToFinal p n = 
    -- trace ("reg = " ++ show (fmap (subtreeRegion . snd) nsPth)) $ 
      firstThat (\(_,(_,s)) -> getFirstOffset s <= p) ns
    where ns = (reverse (nodesOnPath n))

-- | Search the tree in pre-order starting at a given node, until finding a leaf which is at
-- or after the given point. An effort is also made to return a leaf as close as possible to @p@.
-- TODO: rename to fromLeafToLeafAt
fromLeafToLeafAfter :: IsTree tree => Point -> Node (tree (Tok a)) -> Node (tree (Tok a))
fromLeafToLeafAfter p (xs,root) = 
    trace "fromLeafToLeafAfter:" $
    trace ("xs = " ++ show xs) $
    trace ("xsValid = " ++ show xsValid) $
    trace ("p = " ++ show p) $
    trace ("leafBeforeP = " ++ show leafBeforeP) $
    trace ("leaf ~ " ++ show (subtreeRegion leaf)) $
    trace ("xs' = " ++ show xs') $
    result
    where xs' = if null candidateLeaves 
                      then []
                      else fst $ firstOrLastThat (\(_,s) -> getFirstOffset s >= p) candidateLeaves
          candidateLeaves = allLeavesRelative relChild n
          (firstOrLastThat,relChild) = if leafBeforeP then (firstThat,afterChild)
                                                      else (lastThat,beforeChild)
          (xsValid,leaf) = wkDown (xs,root)
          leafBeforeP = getFirstOffset leaf <= p
          n = (xsValid,root)
          result = (xs',root)

allLeavesRelative :: IsTree tree => (Int -> [(Int, tree a)] -> [(Int, tree a)]) -> Node (tree a) -> [Node (tree a)]
allLeavesRelative select 
   = filter (not . nullSubtree . snd) . allLeavesRelative' select . reverse . nodesAndChildIndex
     -- we remove empty subtrees because their region is [0,0].

-- | Takes a list of (node, index of already inspected child), and return all leaves
-- in this node after the said child).
allLeavesRelative' :: IsTree tree => (Int -> [(Int, tree a)] -> [(Int, tree a)]) -> [(Node (tree a), Int)] -> [Node (tree a)]
allLeavesRelative' select l 
  = [(xs ++ xs', t') | ((xs,t),c) <- l, (xs',t') <- allLeavesRelativeChild select c t] 

-- | Given a root, return all the nodes encountered along it, their
-- paths, and the index of the child which comes next.
nodesAndChildIndex :: IsTree tree => Node (tree a) -> [(Node (tree a), Int)]
nodesAndChildIndex ([],t) = [(([],t),negate 1)]
nodesAndChildIndex (x:xs, t) = case index x (subtrees t) of
    Just c' -> (([],t), x) : fmap (first $ first (x:)) (nodesAndChildIndex (xs,c'))
    Nothing -> [(([],t),negate 1)]
          
nodesOnPath :: IsTree tree => Node (tree a) -> [(Path, Node (tree a))]
nodesOnPath ([],t) = [([],([],t))]
nodesOnPath (x:xs,t) = ([],(x:xs,t)) : case index x (subtrees t) of
                           Nothing -> error "nodesOnPath: non-existent path"
                           Just c -> fmap (first (x:)) (nodesOnPath (xs,c))


beforeChild, afterChild :: Int -> [a] -> [a]

beforeChild (-1) = reverse -- (-1) indicates that all children should be taken.
beforeChild c = reverse . take (c-1)

afterChild c = drop (c+1)

-- Return all leaves after or before child depending on the relation which is given.
allLeavesRelativeChild :: IsTree tree => (Int -> [(Int, tree a)] -> [(Int, tree a)]) -> Int -> tree a -> [Node (tree a)]
allLeavesRelativeChild select c t 
    | null ts = [([], t)]
    | otherwise = [(x:xs,t') | (x,ct) <- select c (zip [0..] ts),
                   (xs, t') <- allLeavesIn select ct]
   where ts = subtrees t                      


-- | Return all leaves (with paths) inside a given root.
allLeavesIn :: (IsTree tree) => (Int -> [(Int, tree a)] -> [(Int, tree a)]) -> tree a -> [Node (tree a)]
allLeavesIn select = allLeavesRelativeChild select (-1)

-- | Return all subtrees in a tree; each element of the return list
-- contains paths to nodes. (Root is at the start of each path)
getAllPaths :: IsTree tree => tree t -> [[tree t]]
getAllPaths t = fmap (++[t]) ([] : concatMap getAllPaths (subtrees t))

goDown :: IsTree tree => Int -> tree t -> Maybe (tree t)
goDown i = index i . subtrees 

index :: Int -> [a] -> Maybe a
index _ [] = Nothing
index 0 (h:_) = Just h
index n (_:t) = index (n-1) t

walkDown :: IsTree tree => Node (tree t) -> Maybe (tree t)
walkDown ([],t) = return t
walkDown (x:xs,t) = goDown x t >>= curry walkDown xs

wkDown :: IsTree tree => Node (tree a) -> Node (tree a)
wkDown ([],t) = ([],t)
wkDown (x:xs,t) = case goDown x t of
    Nothing -> ([],t)
    Just t' -> first (x:) $ wkDown (xs,t')

-- | Search the given list, and return the last tree before the given
-- point; with path to the root. (Root is at the start of the path)
getLastPath :: IsTree tree => [tree (Tok t)] -> Point -> Maybe [tree (Tok t)]
getLastPath roots offset =
    case takeWhile ((< offset) . posnOfs . snd) allSubPathPosn of
      [] -> Nothing
      xs -> Just $ fst $ last xs
    where allSubPathPosn = [(p,posn) | root <- roots, p@(t':_) <- getAllPaths root, 
                            Just tok <- [getFirstElement t'], let posn = tokPosn tok]

-- | Return all subtrees in a tree, in preorder.
getAllSubTrees :: IsTree tree => tree t -> [tree t]
getAllSubTrees t = t : concatMap getAllSubTrees (subtrees t)

-- | Return the 1st token of a subtree.
getFirstElement :: Foldable t => t a -> Maybe a
getFirstElement tree = getFirst $ foldMap (First . Just) tree

nullSubtree :: Foldable t => t a -> Bool
nullSubtree = null . toList

getFirstTok, getLastTok :: Foldable t => t a -> Maybe a

getFirstTok = getFirstElement
getLastTok = getLastElement

-- | Return the last token of a subtree.
getLastElement :: Foldable t => t a -> Maybe a
getLastElement tree = getLast $ foldMap (Last . Just) tree

getFirstOffset, getLastOffset :: Foldable t => t (Tok t1) -> Point
getFirstOffset = maybe 0 tokBegin . getFirstTok
getLastOffset = maybe 0 tokEnd . getLastTok

subtreeRegion :: Foldable t => t (Tok t1) -> Region
subtreeRegion t = mkRegion (getFirstOffset t) (getLastOffset t) 

-- | Given a tree, return (first offset, number of lines).
getSubtreeSpan :: (Foldable tree) => tree (Tok t) -> (Point, Int)
getSubtreeSpan tree = (posnOfs $ first, lastLine - firstLine)
    where bounds@[first, _last] = fmap (tokPosn . assertJust) [getFirstElement tree, getLastElement tree]
          [firstLine, lastLine] = fmap posnLine bounds
          assertJust (Just x) = x
          assertJust _ = error "assertJust: Just expected"

-------------------------------------
-- Should be in Control.Applicative.?

sepBy :: (Alternative f) => f a -> f v -> f [a]
sepBy p s   = sepBy1 p s <|> pure []

sepBy1 :: (Alternative f) => f a -> f v -> f [a]
sepBy1 p s  = (:) <$> p <*> many (s *> p)


----------------------------------------------------
-- Testing code.

#ifdef TESTING

nodeRegion :: IsTree tree => Node (tree (Tok a)) -> Region
nodeRegion n = subtreeRegion t
    where Just t = walkDown n

data Test a = Empty | Leaf a | Bin (Test a) (Test a) deriving (Show, Eq)

instance Foldable Test where
    foldMap f (Leaf x) = f x
    foldMap f (Bin _ r) = foldMap f r <> foldMap f r

instance IsTree Test where
    uniplate (Bin l r) = ([l,r],\[l',r'] -> Bin l' r')
    uniplate t = ([],\[] -> t)
    emptyNode = Empty

type TT = Tok ()

instance Arbitrary (Test TT) where
    arbitrary = sized $ \size -> do
      arbitraryFromList [1..size+1]
    shrink (Leaf _) = []
    shrink (Bin l r) = [l,r] ++  (Bin <$> shrink l <*> pure r) ++  (Bin <$> pure l <*> shrink r)

tAt :: Point -> TT
tAt idx =  Tok () 1 (Posn (idx * 2) 0 0)

arbitraryFromList :: [Int] -> Gen (Test TT)
arbitraryFromList [] = error "arbitraryFromList expects non empty lists"
arbitraryFromList [x] = pure (Leaf (tAt (fromIntegral x)))
arbitraryFromList xs = do
  m <- choose (1,length xs - 1)
  let (l,r) = splitAt m xs
  Bin <$> arbitraryFromList l <*> arbitraryFromList r

instance Eq (Tok a) where
    x == y = tokPosn x == tokPosn y
    
instance Arbitrary Region where
    arbitrary = sized $ \size -> do
        x0 :: Int <- arbitrary
        return $ mkRegion (fromIntegral x0) (fromIntegral (x0 + size))

newtype NTTT = N (Node (Test TT)) deriving Show

instance Arbitrary NTTT where
    arbitrary = do
      t <- arbitrary
      p <- arbitraryPath t
      return $ N (p,t)

arbitraryPath :: Test t -> Gen Path
arbitraryPath (Leaf _) = return []
arbitraryPath (Bin l r) = do
  c <- choose (0,1)
  let Just n' = index c [l,r]
  (c :) <$> arbitraryPath n'

regionInside :: Region -> Gen Region
regionInside r = do
  b :: Int <- choose (fromIntegral $ regionStart r, fromIntegral $ regionEnd r)
  e :: Int <- choose (b, fromIntegral $ regionEnd r)
  return $ mkRegion (fromIntegral b) (fromIntegral e)

pointInside :: Region -> Gen Point
pointInside r = do
  p :: Int <- choose (fromIntegral $ regionStart r, fromIntegral $ regionEnd r)
  return (fromIntegral p)

prop_fromLeafAfterToFinal :: NTTT -> Property
prop_fromLeafAfterToFinal (N n) = let
    fullRegion = subtreeRegion $ snd n
 in forAll (pointInside fullRegion) $ \p -> do
   let final@(finalPath, (pathFromSubtree, finalSubtree)) = fromLeafAfterToFinal p n
       finalRegion = subtreeRegion finalSubtree
       initialRegion = nodeRegion n
       
   whenFail (do putStrLn $ "final = " ++ show final
                putStrLn $ "final reg = " ++ show finalRegion
                putStrLn $ "initialReg = " ++ show initialRegion
                putStrLn $ "p = " ++ show p
            ) 
     ((regionStart finalRegion <= p) && (initialRegion `includedRegion` finalRegion))

prop_allLeavesAfter :: NTTT -> Property
prop_allLeavesAfter (N n@(xs,t)) = do
  let after = allLeavesRelative afterChild n
  (xs',t') <- elements after
  let t'' = walkDown (xs',t)
  whenFail (do putStrLn $ "t' = " ++ show t'
               putStrLn $ "t'' = " ++ show t''
               putStrLn $ "xs' = " ++ show xs'
           ) (Just t' == t'' && xs <= xs')

prop_allLeavesBefore :: NTTT -> Property
prop_allLeavesBefore (N n@(xs,t)) = do
  let after = allLeavesRelative beforeChild n
  (xs',t') <- elements after
  let t'' = walkDown (xs',t)
  whenFail (do putStrLn $ "t' = " ++ show t'
               putStrLn $ "t'' = " ++ show t''
               putStrLn $ "xs' = " ++ show xs'
           ) (Just t' == t'' && xs' <= xs)

prop_fromNodeToLeafAfter :: NTTT -> Property
prop_fromNodeToLeafAfter (N n) = forAll (pointInside (subtreeRegion $ snd n)) $ \p -> do
   let after = fromLeafToLeafAfter p n
       afterRegion = nodeRegion after
   whenFail (do putStrLn $ "after = " ++ show after
                putStrLn $ "after reg = " ++ show afterRegion
            ) 
     (regionStart afterRegion >= p)

prop_fromNodeToFinal :: NTTT -> Property
prop_fromNodeToFinal  (N t) = forAll (regionInside (subtreeRegion $ snd t)) $ \r -> do
   let final@(finalPath, finalSubtree) = fromNodeToFinal r t
       finalRegion = subtreeRegion finalSubtree
   whenFail (do putStrLn $ "final = " ++ show final
                putStrLn $ "final reg = " ++ show finalRegion
                putStrLn $ "leaf after = " ++ show (fromLeafToLeafAfter (regionEnd r) t)
            ) $ do
     r `includedRegion` finalRegion

#endif

